{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "ea9e29a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn # 导入torch.nn库\n",
    "d_k = 64 # K(=Q)维度\n",
    "d_v = 64 # V维度\n",
    "# 定义缩放点积注意力类\n",
    "class ScaledDotProductAttention(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(ScaledDotProductAttention, self).__init__()        \n",
    "    def forward(self, Q, K, V, attn_mask):        \n",
    "        # Q K V [batch_size, n_heads, len_q/k/v, dim_q=k/v] (dim_q=dim_k)\n",
    "        # 计算注意力分数（原始权重）[batch_size，n_heads，len_q，len_k]\n",
    "        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) \n",
    "        # 使用注意力掩码，将attn_mask中值为1的位置的权重替换为极小值\n",
    "        # attn_mask [batch_size,n_heads,len_q,len_k],形状和scores相同\n",
    "        scores.masked_fill_(attn_mask, -1e9) \n",
    "        # 对注意力分数进行softmax\n",
    "        weights = nn.Softmax(dim=-1)(scores)\n",
    "        # 计算上下文向量（也就是注意力的输出）, 是上下文信息的紧凑表示\n",
    "        context = torch.matmul(weights, V)\n",
    "        return context, weights # 返回上下文向量和注意力分数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "cc173797",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义多头注意力类\n",
    "d_embedding = 512  # Embedding Size\n",
    "n_heads = 8  # number of heads in Multi-Head Attention\n",
    "batch_size = 3 # 每一批数据量\n",
    "class MultiHeadAttention(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(MultiHeadAttention, self).__init__()\n",
    "        self.W_Q = nn.Linear(d_embedding, d_k * n_heads) # Q的线性变换层\n",
    "        self.W_K = nn.Linear(d_embedding, d_k * n_heads) # K的线性变换层\n",
    "        self.W_V = nn.Linear(d_embedding, d_v * n_heads) # V的线性变换层\n",
    "        self.linear = nn.Linear(n_heads * d_v, d_embedding)\n",
    "        self.layer_norm = nn.LayerNorm(d_embedding)\n",
    "\n",
    "    def forward(self, Q, K, V, attn_mask): \n",
    "        # Q K V [batch_size,len_q/k/v,embedding_dim]        \n",
    "        residual, batch_size = Q, Q.size(0) # 保留残差连接\n",
    "        # 将输入进行线性变换和重塑，以便后续处理\n",
    "        # q_s k_s v_s: [batch_size,n_heads.,len_q/k/v,d_q=k/v]\n",
    "        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)        \n",
    "        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)\n",
    "        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)\n",
    "        # 将注意力掩码复制到多头 [batch_size,n_heads,len_q,len_k]\n",
    "        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)\n",
    "        # 使用缩放点积注意力计算上下文和注意力权重\n",
    "        context, weights = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)\n",
    "        # 重塑上下文向量并进行线性变换，[batch_size，len_q，n_heads * dim_v]\n",
    "        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v) \n",
    "        output = self.linear(context)\n",
    "        # 与输入(Q)进行残差链接，并进行层归一化后输出[batch_size, len_q, embedding_dim]\n",
    "        output = self.layer_norm(output + residual)\n",
    "        return output, weights # 返回层归一化的输出和注意力权重"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "ea0d367d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义逐位置前向传播网络类\n",
    "class PoswiseFeedForwardNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(PoswiseFeedForwardNet, self).__init__()\n",
    "        # 定义一维卷积层1，用于将输入映射到更高维度\n",
    "        self.conv1 = nn.Conv1d(in_channels=d_embedding, out_channels=2048, kernel_size=1)\n",
    "        # 定义一维卷积层2，用于将输入映射回原始维度\n",
    "        self.conv2 = nn.Conv1d(in_channels=2048, out_channels=d_embedding, kernel_size=1)\n",
    "        # 定义层归一化\n",
    "        self.layer_norm = nn.LayerNorm(d_embedding)\n",
    "\n",
    "    def forward(self, inputs): \n",
    "        # inputs: [batch_size, len_q, embedding_dim]        \n",
    "        residual = inputs  # 保留残差连接\n",
    "        # 在卷积层1后使用ReLU激活函数\n",
    "        output = nn.ReLU()(self.conv1(inputs.transpose(1, 2)))\n",
    "        # 使用卷积层2进行降维\n",
    "        output = self.conv2(output).transpose(1, 2)\n",
    "        # 与输入进行残差链接，并进行层归一化，[batch_size, len_q, embedding_dim]\n",
    "        output = self.layer_norm(output + residual)\n",
    "        return output # 返回层归一化后的输出加上残差连接的结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "f1dab244",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def get_sin_enc_table(n_position, embedding_dim):\n",
    "    # 根据位置和维度信息，初始化正弦位置编码表\n",
    "    sinusoid_table = np.zeros((n_position, embedding_dim))    \n",
    "    # 遍历所有位置和维度，计算角度值\n",
    "    for pos_i in range(n_position):\n",
    "        for hid_j in range(embedding_dim):\n",
    "            angle = pos_i / np.power(10000, 2 * (hid_j // 2) / embedding_dim)\n",
    "            sinusoid_table[pos_i, hid_j] = angle    \n",
    "    # 计算正弦和余弦值\n",
    "    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # dim 2i 偶数维\n",
    "    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # dim 2i+1 奇数维    \n",
    "    return torch.FloatTensor(sinusoid_table)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "f1f32613",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 生成填充注意力掩码的函数，用于在多头自注意力计算中忽略填充部分\n",
    "def get_attn_pad_mask(seq_q, seq_k):\n",
    "    batch_size, len_q = seq_q.size()\n",
    "    batch_size, len_k = seq_k.size()\n",
    "    # 生成布尔类型张量[batch_size，1，len_k(=len_q)]\n",
    "    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  #<PAD> Token的编码值为0 \n",
    "    # 变形为何注意力分数相同形状的张量 [batch_size，len_q，len_k]\n",
    "    pad_attn_mask = pad_attn_mask.expand(batch_size, len_q, len_k) \n",
    "    return pad_attn_mask # 形状[batch_size，len_q，len_k]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "c2cb5167",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 生成后续注意力掩码的函数，用于在多头自注意力计算中忽略未来信息\n",
    "def get_attn_subsequent_mask(seq):\n",
    "    # 获取输入序列的形状 [batch_size, seq_len(len_q), seq_len(len_k)]\n",
    "    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]\n",
    "    # 使用numpy创建一个上三角矩阵（triu = triangle upper）\n",
    "    subsequent_mask = np.triu(np.ones(attn_shape), k=1)\n",
    "    # 将numpy数组转换为PyTorch张量，并将数据类型设置为byte（布尔值）\n",
    "    subsequent_mask = torch.from_numpy(subsequent_mask).byte()\n",
    "    return subsequent_mask # [batch_size, seq_len(len_q), seq_len(len_k)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "e8bfb20e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建解码器层\n",
    "class DecoderLayer(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(DecoderLayer, self).__init__()\n",
    "        self.self_attn = MultiHeadAttention()  # 多头自注意力层\n",
    "        self.feed_forward = PoswiseFeedForwardNet()  # 位置前馈神经网络层\n",
    "        self.norm1 = nn.LayerNorm(d_embedding)  # 第一个层归一化\n",
    "        self.norm2 = nn.LayerNorm(d_embedding)  # 第二个层归一化\n",
    "\n",
    "    def forward(self, dec_inputs, attn_mask=None):\n",
    "        # 使用多头自注意力处理输入\n",
    "        attn_output, _ = self.self_attn(dec_inputs, dec_inputs, dec_inputs, attn_mask)\n",
    "        # 将注意力输出与输入相加并进行第一个层归一化\n",
    "        norm1_outputs = self.norm1(dec_inputs + attn_output)\n",
    "        # 将归一化后的输出输入到位置前馈神经网络\n",
    "        ff_outputs = self.feed_forward(norm1_outputs)\n",
    "        # 将前馈神经网络输出与第一次归一化后的输出相加并进行第二个层归一化\n",
    "        dec_outputs = self.norm2(norm1_outputs + ff_outputs)\n",
    "        return dec_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "33d35d08",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建解码器\n",
    "n_layers = 6  # 设置Encoder/Decoder的层数\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"  # 设置设备\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, vocab_size, max_seq_len):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.src_emb = nn.Embedding(vocab_size, d_embedding)  # 词嵌入层（参数为词典维度）\n",
    "        self.pos_emb = nn.Embedding(max_seq_len, d_embedding)  # 位置编码层（参数为序列长度）        \n",
    "        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)]) # 初始化N个解码器层\n",
    "\n",
    "    def forward(self, dec_inputs):        \n",
    "        positions = torch.arange(len(dec_inputs), device=dec_inputs.device).unsqueeze(-1) #位置信息        \n",
    "        inputs_embedding = self.src_emb(dec_inputs) + self.pos_emb(positions) # 词嵌入与位置编码相加        \n",
    "        attn_mask = get_attn_subsequent_mask(inputs_embedding).to(device) # 生成自注意力掩码        \n",
    "        for layer in self.layers:\n",
    "            dec_outputs = layer(inputs_embedding, attn_mask) # 将输入数据传递给解码器层\n",
    "        return dec_outputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "82e6acd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPT(nn.Module):\n",
    "    def __init__(self, vocab_size, max_seq_len):\n",
    "        super(GPT, self).__init__()\n",
    "        self.decoder = Decoder(vocab_size, max_seq_len) # 解码器，用于学习文本生成能力\n",
    "        self.projection = nn.Linear(d_embedding, vocab_size)  # 全连接层，输出预测结果\n",
    "\n",
    "    def forward(self, dec_inputs):        \n",
    "        dec_outputs = self.decoder(dec_inputs) # 将输入数据传递给解码器\n",
    "        logits = self.projection(dec_outputs) # 传递给全连接层以生成预测\n",
    "        return logits #返回预测结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "991a37e5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "词汇表大小: 28785\n",
      "词汇示例(word to index): {'<pad>': 0, '<sos>': 1, '<eos>': 2, 'the': 3, 'apple': 11505}\n"
     ]
    }
   ],
   "source": [
    "from torchtext.datasets import WikiText2 # 导入WikiText2\n",
    "from torchtext.data.utils import get_tokenizer # 导入Tokenizer分词工具\n",
    "from torchtext.vocab import build_vocab_from_iterator # 导入Vocabulary工具\n",
    "from torch.utils.data import DataLoader, Dataset # 导入Pytorch的DataLoader和Dataset\n",
    "\n",
    "tokenizer = get_tokenizer(\"basic_english\") # 定义数据预处理所需的tokenizer\n",
    "\n",
    "train_iter = WikiText2(split='train') # 加载WikiText2数据集的训练部分\n",
    "valid_iter = WikiText2(split='valid') # 加载WikiText2数据集的验证部分\n",
    "\n",
    "# 定义一个生成器函数，用于将数据集中的文本转换为tokens\n",
    "def yield_tokens(data_iter):\n",
    "    for item in data_iter:\n",
    "        yield tokenizer(item)\n",
    "\n",
    "# 创建词汇表，包括特殊tokens：\"<pad>\", \"<sos>\", \"<eos>\"\n",
    "vocab = build_vocab_from_iterator(yield_tokens(train_iter), \n",
    "                                  specials=[\"<pad>\", \"<sos>\", \"<eos>\"])\n",
    "vocab.set_default_index(vocab[\"<pad>\"])\n",
    "\n",
    "# 打印词汇表信息\n",
    "print(\"词汇表大小:\", len(vocab))\n",
    "print(\"词汇示例(word to index):\", \n",
    "      {word: vocab[word] for word in [\"<pad>\", \"<sos>\", \"<eos>\", \"the\", \"apple\"]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "e33341df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset数据条目: 36718\n",
      "输入序列张量样例: tensor([    1,  2659,  3478, 17569,  9098])\n",
      "目标序列张量样例: tensor([ 2659,  3478, 17569,  9098,     2])\n",
      "输入序列样例文本: <sos> 96 ammunition packing boxes\n",
      "目标序列样例文本: 96 ammunition packing boxes <eos>\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import Dataset # 导入Dataset\n",
    "max_seq_len = 256 # 设置序列的最大长度\n",
    "\n",
    "# 定义一个处理WikiText2数据集的自定义数据集类\n",
    "class WikiDataset(Dataset):\n",
    "    def __init__(self, data_iter, vocab, max_len=max_seq_len):\n",
    "        self.data = []        \n",
    "        for sentence in data_iter: # 遍历数据集，将文本转换为tokens\n",
    "            # 对每个句子进行tokenization，并截取长度为max_len-2，为<sos>和<eos>留出空间\n",
    "            tokens = tokenizer(sentence)[:max_len - 2]\n",
    "            tokens = [vocab[\"<sos>\"]] + vocab(tokens) + [vocab[\"<eos>\"]] # 添加<sos>和<eos>            \n",
    "            self.data.append(tokens) # 将处理好的tokens添加到数据集中\n",
    "    \n",
    "    def __len__(self): # 定义数据集的长度\n",
    "        return len(self.data)    \n",
    "    \n",
    "    def __getitem__(self, idx): # 定义数据集的索引方法 (即抽取数据条目)        \n",
    "        source = self.data[idx][:-1] # 获取当前数据，并将<eos>移除，作为source        \n",
    "        target = self.data[idx][1:] # 获取当前数据，并将<sos>移除，作为target（右移1位）       \n",
    "        return torch.tensor(source), torch.tensor(target) # 转换为tensor并返回\n",
    "\n",
    "train_dataset = WikiDataset(train_iter, vocab) # 创建训练数据集\n",
    "valid_dataset = WikiDataset(valid_iter, vocab) # 创建验证数据集\n",
    "print(f\"Dataset数据条目: {len(train_dataset)}\")\n",
    "sample_source, sample_target = train_dataset[100]\n",
    "print(f\"输入序列张量样例: {sample_source}\")\n",
    "print(f\"目标序列张量样例: {sample_target}\")\n",
    "decoded_source = ' '.join(vocab.lookup_tokens(sample_source.tolist()))\n",
    "decoded_target = ' '.join(vocab.lookup_tokens(sample_target.tolist()))\n",
    "print(f\"输入序列样例文本: {decoded_source}\")\n",
    "print(f\"目标序列样例文本: {decoded_target}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "c61f26b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader # 导入Dataloader\n",
    "# 定义pad_sequence函数，用于将一批序列补齐到相同长度\n",
    "def pad_sequence(sequences, padding_value=0, length=None):\n",
    "    # 计算最大序列长度，如果length参数未提供，则使用输入序列中的最大长度\n",
    "    max_length = max(len(seq) for seq in sequences) if length is None else length    \n",
    "    # 创建一个具有适当形状的全零张量，用于存储补齐后的序列\n",
    "    result = torch.full((len(sequences), max_length), padding_value, dtype=torch.long)    \n",
    "    # 遍历序列，将每个序列的内容复制到结果张量中\n",
    "    for i, seq in enumerate(sequences):\n",
    "        end = len(seq)\n",
    "        result[i, :end] = seq[:end]\n",
    "    return result\n",
    "\n",
    "# 定义collate_fn函数，用于将一个批次的数据整理成适当的形状\n",
    "def collate_fn(batch):\n",
    "    # 从批次中分离源序列和目标序列\n",
    "    sources, targets = zip(*batch)    \n",
    "    # 计算批次中的最大序列长度\n",
    "    max_length = max(max(len(s) for s in sources), max(len(t) for t in targets))    \n",
    "    # 使用pad_sequence函数补齐源序列和目标序列\n",
    "    sources = pad_sequence(sources, padding_value=vocab[\"<pad>\"], length=max_length)\n",
    "    targets = pad_sequence(targets, padding_value=vocab[\"<pad>\"], length=max_length)    \n",
    "    # 返回补齐后的源序列和目标序列\n",
    "    return sources, targets\n",
    "\n",
    "# 创建一个训练数据加载器，使用自定义的collate_fn函数\n",
    "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, \n",
    "                              shuffle=True, collate_fn=collate_fn)\n",
    "# 创建一个验证数据加载器，使用自定义的collate_fn函数\n",
    "valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size,\n",
    "                              shuffle=False, collate_fn=collate_fn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "54dd5b39",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim  # 导入优化器\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"  # 设置设备\n",
    "model = GPT(len(vocab), max_seq_len).to(device)  # 创建GPT模型实例\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=vocab[\"<pad>\"])\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.0001)  # 优化器\n",
    "epochs = 2  # 训练轮次\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    epoch_loss = 0\n",
    "    for batch_idx, (source, target) in enumerate(train_dataloader): # 用Dataloader加载数据\n",
    "        inputs, targets = source.to(device), target.to(device)\n",
    "        optimizer.zero_grad()  # 梯度清零\n",
    "        outputs = model(inputs)  # 获取模型输出\n",
    "        loss = criterion(outputs.view(-1, len(vocab)), targets.view(-1))  # 计算损失\n",
    "        loss.backward()  # 反向传播\n",
    "        optimizer.step()  # 更新参数\n",
    "        epoch_loss += loss.item()        \n",
    "        if (batch_idx + 1) % 500 == 0: # 每500个批次打印一次损失\n",
    "            print(f\"Batch {batch_idx + 1}/{len(train_dataloader)}, Loss: {loss.item()}\")    \n",
    "    epoch_loss /= len(train_dataloader) # 每轮打印一次损失\n",
    "    print(f\"Epoch {epoch + 1}/{epochs}, Average Loss: {epoch_loss}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "1401a8fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import time\n",
    "# from datetime import datetime\n",
    "\n",
    "# # Save the trained model\n",
    "# timestamp = datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d_%H-%M-%S')\n",
    "# model_file_name = f\"trained_model_{timestamp}.pt\"\n",
    "# torch.save(model.state_dict(), model_file_name)\n",
    "# print(f\"Model saved as {model_file_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "id": "5b58ecd2",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1792588/278596197.py:16: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525551200/work/aten/src/ATen/native/cuda/Indexing.cu:1435.)\n",
      "  scores.masked_fill_(attn_mask, -1e9)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生成的文本: how are you ' species nests . common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common starlings are common\n"
     ]
    }
   ],
   "source": [
    "# Replace 'model_timestamp.pt' with your saved model's filename\n",
    "model.load_state_dict(torch.load('trained_model_2023-05-05_14-08-24.pt'))\n",
    "# 测试文本生成\n",
    "def generate_text_greedy_search(model, input_str, max_len=50):\n",
    "    model.eval()  # 将模型设置为评估（测试）模式，关闭dropout和batch normalization等训练相关的层\n",
    "    # 将输入字符串中的每个Token 转换为其在词汇表中的索引\n",
    "    input_tokens = [vocab[token] for token in input_str.split()]\n",
    "    # 创建一个新列表，将输入的Token复制到输出Token中,目前只有输入的词\n",
    "    output_tokens = input_tokens.copy()\n",
    "    with torch.no_grad():  # 禁用梯度计算，以节省内存并加速测试过程\n",
    "        for _ in range(max_len):  # 生成最多max_len个Token\n",
    "            # 将输出token转换为 PyTorch张量，并增加一个代表批次的维度[1, len(output_tokens)]\n",
    "            inputs = torch.LongTensor(output_tokens).unsqueeze(0).to(device)\n",
    "            outputs = model(inputs) # 输出 logits形状为[1, len(output_tokens), vocab_size]\n",
    "            logits = outputs[:, -1, :] # 只关心最后一个时间步（即最新生成的token）的logits\n",
    "            # 在最后一个维度上获取logits中的最大值，并返回其索引（即下一个Token）\n",
    "            _, next_token = torch.max(logits, dim=-1)            \n",
    "            next_token = next_token.item() # 将张量转换为Python整数            \n",
    "            if next_token == vocab[\"<eos>\"]:\n",
    "                break # 如果生成的Token是 EOS（结束符），则停止生成过程           \n",
    "            output_tokens.append(next_token) # 将生成的Token添加到output_tokens列表\n",
    "    # 将输出Token转换回文本字符串\n",
    "    output_str = \" \".join([vocab.get_itos()[token] for token in output_tokens\n",
    "                           if vocab.get_itos()[token] != \"<pad>\" and vocab.get_itos()[token] != \"<unk>\" ])\n",
    "    return output_str\n",
    "\n",
    "input_str = \"how are you\" # 输入一个词：Python\n",
    "generated_text = generate_text_greedy_search(model, input_str) # 模型跟着这个字生成后续文本\n",
    "print(\"生成的文本:\", generated_text) # 打印预测文本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "eb185470",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1792588/278596197.py:16: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525551200/work/aten/src/ATen/native/cuda/Indexing.cu:1435.)\n",
      "  scores.masked_fill_(attn_mask, -1e9)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "生成的文本: my name was also used in 1897 by lucasfilm games in the common by lucasfilm games in the common by lucasfilm games in the common by lucasfilm games in the common by lucasfilm games in the common by lucasfilm games in the common by lucasfilm games in the common by lucasfilm games\n"
     ]
    }
   ],
   "source": [
    "# 定义集束搜索的函数\n",
    "def generate_text_beam_search(model, input_str, max_len=50, beam_width=5):\n",
    "    model.eval()  # 将模型设置为评估（测试）模式，关闭dropout和batch normalization等训练相关的层\n",
    "    # 将输入字符串中的每个token 转换为其在词汇表中的索引\n",
    "    input_tokens = [vocab[token] for token in input_str.split()]\n",
    "    # 创建一个列表，用于存储候选序列\n",
    "    candidates = [(input_tokens, 0.0)]\n",
    "    with torch.no_grad():  # 禁用梯度计算，以节省内存并加速测试过程\n",
    "        for _ in range(max_len):  # 生成最多max_len个tokens\n",
    "            new_candidates = []\n",
    "            for candidate, candidate_score in candidates:\n",
    "                inputs = torch.LongTensor(candidate).unsqueeze(0).to(device)\n",
    "                outputs = model(inputs) # 输出 logits形状为[1, len(output_tokens), vocab_size]\n",
    "                logits = outputs[:, -1, :] # 只关心最后一个时间步（即最新生成的token）的logits\n",
    "                # 找到具有最高分数的前beam_width个tokens\n",
    "                scores, next_tokens = torch.topk(logits, beam_width, dim=-1)\n",
    "                final_results = [] # 初始化输出序列\n",
    "                for score, next_token in zip(scores.squeeze(), next_tokens.squeeze()):\n",
    "                    new_candidate = candidate + [next_token.item()]\n",
    "                    new_score = candidate_score - score.item()  # 使用负数，因为我们需要降序排列\n",
    "                    if next_token.item() == vocab[\"<eos>\"]:\n",
    "                        # 如果生成的token是EOS（结束符），将其添加到最终结果中\n",
    "                        final_results.append((new_candidate, new_score))\n",
    "                    else:\n",
    "                        # 将新生成的候选序列添加到新候选列表中\n",
    "                        new_candidates.append((new_candidate, new_score))\n",
    "            # 从新候选列表中选择得分最高的beam_width个序列\n",
    "            candidates = sorted(new_candidates, key=lambda x: x[1])[:beam_width]\n",
    "    # 选择得分最高的候选序列\n",
    "    best_candidate, _ = sorted(candidates, key=lambda x: x[1])[0]\n",
    "    # 将输出 token 转换回文本字符串\n",
    "    output_str = \" \".join([vocab.get_itos()[token] for token in best_candidate if vocab.get_itos()[token] != \"<pad>\"])\n",
    "    return output_str\n",
    "\n",
    "model.load_state_dict(torch.load('trained_model_2023-05-05_14-08-24.pt')) # 加载模型\n",
    "input_str = \"my name\"  # 输入几个词\n",
    "generated_text = generate_text_beam_search(model, input_str)  # 模型跟着这些词生成后续文本\n",
    "print(\"生成的文本:\", generated_text)  # 打印生成的文本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8207f043",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af42a2fe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
